{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Simple Convnet - MNIST\n",
    "\n",
    "Slightly modified from mnist_cnn.py in the Keras examples folder:\n",
    "\n",
    "https://github.com/keras-team/keras/blob/master/examples/mnist_cnn.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "KERAS_MODEL_FILEPATH = '../../demos/data/mnist_cnn/mnist_cnn.h5'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n",
      "/home/leon/miniconda3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: compiletime version 3.5 of module 'tensorflow.python.framework.fast_tensor_util' does not match runtime version 3.6\n",
      "  return f(*args, **kwds)\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "np.random.seed(1337)  # for reproducibility\n",
    "\n",
    "from keras.datasets import mnist\n",
    "from keras.models import Sequential\n",
    "from keras.layers import Dense, Dropout, Flatten, Activation\n",
    "from keras.layers import Conv2D, MaxPooling2D\n",
    "from keras.utils import np_utils\n",
    "from keras.callbacks import EarlyStopping, ModelCheckpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x_train shape: (60000, 28, 28, 1)\n",
      "60000 train samples\n",
      "10000 test samples\n"
     ]
    }
   ],
   "source": [
    "num_classes = 10\n",
    "\n",
    "# input image dimensions\n",
    "img_rows, img_cols = 28, 28\n",
    "\n",
    "# the data, shuffled and split between train and test sets\n",
    "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
    "\n",
    "x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)\n",
    "x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)\n",
    "input_shape = (img_rows, img_cols, 1)\n",
    "\n",
    "x_train = x_train.astype('float32')\n",
    "x_test = x_test.astype('float32')\n",
    "x_train /= 255\n",
    "x_test /= 255\n",
    "print('x_train shape:', x_train.shape)\n",
    "print(x_train.shape[0], 'train samples')\n",
    "print(x_test.shape[0], 'test samples')\n",
    "\n",
    "# convert class vectors to binary class matrices\n",
    "y_train = np_utils.to_categorical(y_train, num_classes)\n",
    "y_test = np_utils.to_categorical(y_test, num_classes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sequential Model\n",
    "\n",
    "model = Sequential()\n",
    "model.add(Conv2D(32, (3, 3), input_shape=input_shape))\n",
    "model.add(Activation('relu'))\n",
    "model.add(Conv2D(32, (3, 3)))\n",
    "model.add(Activation('relu'))\n",
    "model.add(MaxPooling2D(pool_size=(2, 2)))\n",
    "model.add(Dropout(0.25))\n",
    "model.add(Flatten())\n",
    "model.add(Dense(128))\n",
    "model.add(Activation('relu'))\n",
    "model.add(Dropout(0.5))\n",
    "model.add(Dense(num_classes))\n",
    "model.add(Activation('softmax'))\n",
    "\n",
    "model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 60000 samples, validate on 10000 samples\n",
      "Epoch 1/100\n",
      "Epoch 00001: val_acc improved from -inf to 0.98270, saving model to mnist_cnn.h5\n",
      " - 4s - loss: 0.2833 - acc: 0.9127 - val_loss: 0.0557 - val_acc: 0.9827\n",
      "Epoch 2/100\n",
      "Epoch 00002: val_acc improved from 0.98270 to 0.98700, saving model to mnist_cnn.h5\n",
      " - 2s - loss: 0.0980 - acc: 0.9708 - val_loss: 0.0407 - val_acc: 0.9870\n",
      "Epoch 3/100\n",
      "Epoch 00003: val_acc improved from 0.98700 to 0.98870, saving model to mnist_cnn.h5\n",
      " - 2s - loss: 0.0747 - acc: 0.9777 - val_loss: 0.0344 - val_acc: 0.9887\n",
      "Epoch 4/100\n",
      "Epoch 00004: val_acc improved from 0.98870 to 0.98960, saving model to mnist_cnn.h5\n",
      " - 2s - loss: 0.0611 - acc: 0.9811 - val_loss: 0.0309 - val_acc: 0.9896\n",
      "Epoch 5/100\n",
      "Epoch 00005: val_acc did not improve\n",
      " - 2s - loss: 0.0542 - acc: 0.9836 - val_loss: 0.0329 - val_acc: 0.9888\n",
      "Epoch 6/100\n",
      "Epoch 00006: val_acc improved from 0.98960 to 0.99140, saving model to mnist_cnn.h5\n",
      " - 2s - loss: 0.0483 - acc: 0.9853 - val_loss: 0.0289 - val_acc: 0.9914\n",
      "Epoch 7/100\n",
      "Epoch 00007: val_acc did not improve\n",
      " - 2s - loss: 0.0432 - acc: 0.9863 - val_loss: 0.0278 - val_acc: 0.9912\n",
      "Epoch 8/100\n",
      "Epoch 00008: val_acc did not improve\n",
      " - 2s - loss: 0.0398 - acc: 0.9873 - val_loss: 0.0295 - val_acc: 0.9902\n",
      "Epoch 9/100\n",
      "Epoch 00009: val_acc did not improve\n",
      " - 2s - loss: 0.0351 - acc: 0.9885 - val_loss: 0.0277 - val_acc: 0.9910\n",
      "Epoch 10/100\n",
      "Epoch 00010: val_acc did not improve\n",
      " - 2s - loss: 0.0333 - acc: 0.9897 - val_loss: 0.0281 - val_acc: 0.9910\n",
      "Epoch 11/100\n",
      "Epoch 00011: val_acc improved from 0.99140 to 0.99210, saving model to mnist_cnn.h5\n",
      " - 2s - loss: 0.0322 - acc: 0.9893 - val_loss: 0.0258 - val_acc: 0.9921\n",
      "Epoch 12/100\n",
      "Epoch 00012: val_acc did not improve\n",
      " - 2s - loss: 0.0287 - acc: 0.9906 - val_loss: 0.0273 - val_acc: 0.9918\n",
      "Epoch 13/100\n",
      "Epoch 00013: val_acc did not improve\n",
      " - 2s - loss: 0.0258 - acc: 0.9915 - val_loss: 0.0291 - val_acc: 0.9918\n",
      "Epoch 14/100\n",
      "Epoch 00014: val_acc did not improve\n",
      " - 2s - loss: 0.0247 - acc: 0.9916 - val_loss: 0.0281 - val_acc: 0.9913\n",
      "Epoch 15/100\n",
      "Epoch 00015: val_acc improved from 0.99210 to 0.99270, saving model to mnist_cnn.h5\n",
      " - 2s - loss: 0.0245 - acc: 0.9922 - val_loss: 0.0279 - val_acc: 0.9927\n",
      "Epoch 16/100\n",
      "Epoch 00016: val_acc did not improve\n",
      " - 2s - loss: 0.0235 - acc: 0.9921 - val_loss: 0.0312 - val_acc: 0.9916\n",
      "Epoch 17/100\n",
      "Epoch 00017: val_acc did not improve\n",
      " - 2s - loss: 0.0228 - acc: 0.9927 - val_loss: 0.0297 - val_acc: 0.9918\n",
      "Epoch 18/100\n",
      "Epoch 00018: val_acc did not improve\n",
      " - 2s - loss: 0.0208 - acc: 0.9931 - val_loss: 0.0308 - val_acc: 0.9917\n",
      "Epoch 19/100\n",
      "Epoch 00019: val_acc did not improve\n",
      " - 2s - loss: 0.0203 - acc: 0.9930 - val_loss: 0.0297 - val_acc: 0.9918\n",
      "Epoch 20/100\n",
      "Epoch 00020: val_acc did not improve\n",
      " - 2s - loss: 0.0179 - acc: 0.9940 - val_loss: 0.0307 - val_acc: 0.9920\n",
      "Epoch 00020: early stopping\n",
      "Test score: 0.0306938055278\n",
      "Test accuracy: 0.992\n"
     ]
    }
   ],
   "source": [
    "# Model saving callback\n",
    "checkpointer = ModelCheckpoint(filepath=KERAS_MODEL_FILEPATH, monitor='val_acc', verbose=1, save_best_only=True)\n",
    "\n",
    "# Early stopping\n",
    "early_stopping = EarlyStopping(monitor='val_acc', verbose=1, patience=5)\n",
    "\n",
    "# Train\n",
    "batch_size = 128\n",
    "epochs = 100\n",
    "model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, verbose=2,\n",
    "          callbacks=[checkpointer, early_stopping], \n",
    "          validation_data=(x_test, y_test))\n",
    "score = model.evaluate(x_test, y_test, verbose=0)\n",
    "print('Test score:', score[0])\n",
    "print('Test accuracy:', score[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
